import asyncio
import socket
import os
import json

from datetime import datetime

from pylog.pylogger import PyLogger

from py_pli.pylib import VUnits
from py_pli.pylib import GlobalVar

import config_enum.excitationlight_selector_enum as els_config
import config_enum.filter_module_slider_enum as fms_config

from virtualunits.HAL import HAL
from virtualunits.vu_excitation_light_selector import VUExcitationLightSelector
from virtualunits.vu_filter_module_slider import VUFilterModuleSlider
from virtualunits.vu_detector_aperture_slider import VUDetectorApertureSlider
from virtualunits.vu_measurement_unit import VUMeasurementUnit
from virtualunits.meas_seq_generator import meas_seq_generator
from virtualunits.meas_seq_generator import TriggerSignal
from virtualunits.meas_seq_generator import OutputSignal

from urpc_enum.measurementparameter import MeasurementParameter

from fleming.common.firmware_util import send_gc_msg
from fleming.common.firmware_util import send_graph


# Virtual-unit handles, looked up once at import time from the VUnits singleton.
# The sub-units are all attributes of the HAL object bound to `hal_unit`.
hal_unit: HAL = VUnits.instance.hal
els_unit: VUExcitationLightSelector = hal_unit.excitationLightSelector
fms_unit: VUFilterModuleSlider = hal_unit.filterModuleSlider
as1_unit: VUDetectorApertureSlider = hal_unit.detectorApertureSlider1
as2_unit: VUDetectorApertureSlider = hal_unit.detectorApertureSlider2
meas_unit: VUMeasurementUnit = hal_unit.measurementUnit

# Directory next to this script into which the scan functions write their .csv reports.
report_dir = os.path.dirname(__file__) + "/flash_scan"


async def init():
    """
    Prepare the instrument for a scan.

    Awaits, in this order, the HAL coroutines StartupHardware, InitializeDevice,
    HomeMovers, TurnLedsOff and TurnOn_PMT_HV, and then sleeps for one second
    (presumably to let the PMT high voltage settle - TODO confirm).
    """
    startup_steps = (
        hal_unit.StartupHardware,
        hal_unit.InitializeDevice,
        hal_unit.HomeMovers,
        hal_unit.TurnLedsOff,
        hal_unit.TurnOn_PMT_HV,
    )
    # The steps must run strictly one after the other, so they are awaited sequentially.
    for step in startup_steps:
        await step()
    await asyncio.sleep(1)
    

async def flash_as_scan(pos_start=0.8, pos_stop=1.4, pos_step=0.01, flash_count=50, flash_mode=0):
    """
    Run flash excitation scans for a range of aperture slider positions and save the results in a .csv file.
    The excitation scan measures the first 100 µs after the flash with a 1 µs resolution using both PMTs in counting mode.

    Both aperture sliders are moved to the same position for each scan. Every argument may also be
    passed as a string; an empty string selects the default value.

    pos_start: The start of the positions range.
    pos_stop: The end of the positions range. The range includes pos_stop.
    pos_step: The step size of the positions range. Must not round to zero at a resolution of 1e-6.
    flash_count: The number of flashes to average for each individual excitation scan.
    flash_mode: 0 = High Speed, 1 = High Power (any value other than 1 selects the low power setting).

    Returns a status string telling whether the scan finished or was stopped by the user.
    Raises ValueError if pos_step is zero, before any hardware is touched or any file is created.
    """
    pos_start = float(pos_start) if (pos_start != '') else 0.8
    pos_stop = float(pos_stop) if (pos_stop != '') else 1.4
    pos_step = float(pos_step) if (pos_step != '') else 0.01
    flash_count = int(flash_count) if (flash_count != '') else 50
    flash_mode = int(flash_mode) if (flash_mode != '') else 0

    window_us = 1.0
    window_count = 100

    # The positions are generated on an integer grid of 1e-6 units and scaled back to floats.
    # A zero step would make range() fail only after the hardware has been started, so reject it here.
    step_micro = round(pos_step * 1e6)
    if step_micro == 0:
        raise ValueError(f"pos_step must not be zero (got {pos_step})")
    pos_range = [pos / 1e6 for pos in range(round(pos_start * 1e6), round(pos_stop * 1e6 + 1), step_micro)]

    await send_gc_msg(f"Starting flash_as_scan(pos_start={pos_start:.3f}, pos_stop={pos_stop:.3f}, pos_step={pos_step:.3f}, flash_count={flash_count}, flash_mode={flash_mode})")

    GlobalVar.set_stop_gc(False)

    await init()

    await els_unit.GotoPosition(els_config.Positions.Flash2)
    await fms_unit.GotoPosition(fms_config.Positions.FixMirrorPosition_c)
    await meas_unit.EnableFlashLampPower(isLow=(flash_mode != 1))

    await send_gc_msg("Scanning Aperture Slider Positions...")

    instrument = socket.gethostname()
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    os.makedirs(report_dir, exist_ok=True)
    with open(f"{report_dir}/flash_as_scan__{instrument}_{timestamp}.csv", 'w') as file:
        # First line: the call parameters. Second line: the end time of each measurement window.
        file.write(f"flash_as_scan(pos_start={pos_start:.3f}, pos_stop={pos_stop:.3f}, pos_step={pos_step:.3f}, flash_count={flash_count}, flash_mode={flash_mode}) on {instrument} started at {timestamp}\n")
        time_axis = "".join(f" ; {(window_us * (i + 1)):5.1f}" for i in range(window_count))
        file.write(f"time [us]{time_axis}\n")

        for as_pos in pos_range:
            # The stop flag is polled once per position; leaving the with-block closes the file.
            if GlobalVar.get_stop_gc() is True:
                return "flash_as_scan() stopped by user"

            await as1_unit.Move(as_pos)
            await as2_unit.Move(as_pos)

            results = await flash_excitation_scan(window_us=window_us, window_count=window_count, flash_count=flash_count)
            # The scan returns the two PMT channels interleaved: even indices PMT1, odd indices PMT2.
            results_pmt1 = results[0::2]
            results_pmt2 = results[1::2]
            for label, values in (("AS1", results_pmt1), ("AS2", results_pmt2)):
                row = "".join(f" ; {value:5.1f}" for value in values[:window_count])
                file.write(f"{label}={as_pos:.3f}{row}\n")
            await send_gc_msg(f"pos: {as_pos:.3f} ; pmt1_max: {max(results_pmt1):5.1f} ; pmt2_max: {max(results_pmt2):5.1f}")

    return "flash_as_scan() done"
    

async def flash_scan(as1_pos, as2_pos, iterations=10, flash_count=50, flash_mode=0, maximum=56, delay=1.0):
    """
    Run repeated flash excitation scans and save the results in a .csv file.
    The excitation scan measures the first 100 µs after the flash with a 1 µs resolution using both PMTs in counting mode.

    Every argument may also be passed as a string. For the optional arguments an empty string
    selects the default value; the two positions are required and must be numeric.

    as1_pos: The position of aperture slider 1.
    as2_pos: The position of aperture slider 2.
    iterations: The number of excitation scans to run.
    flash_count: The number of flashes to average for each individual excitation scan.
    flash_mode: 0 = High Speed, 1 = High Power (any value other than 1 selects the low power setting).
    maximum: The maximum of the flash scan graph.
    delay: The delay after showing each graph in s.

    Returns a status string telling whether the scan finished or was stopped by the user.
    Raises ValueError if as1_pos or as2_pos cannot be converted to float.
    """
    # Convert the positions like every other argument, so that string input does not
    # fail later when the values are formatted with ':.3f'.
    as1_pos = float(as1_pos)
    as2_pos = float(as2_pos)
    iterations = int(iterations) if (iterations != '') else 10
    flash_count = int(flash_count) if (flash_count != '') else 50
    flash_mode = int(flash_mode) if (flash_mode != '') else 0
    maximum = int(maximum) if (maximum != '') else 56
    delay = float(delay) if (delay != '') else 1.0

    window_us = 1.0
    window_count = 100

    await send_gc_msg(f"Starting flash_scan(as1_pos={as1_pos:.3f}, as2_pos={as2_pos:.3f}, iterations={iterations}, flash_count={flash_count}, flash_mode={flash_mode}, maximum={maximum}, delay={delay:.3f})")

    GlobalVar.set_stop_gc(False)

    await init()

    await els_unit.GotoPosition(els_config.Positions.Flash2)
    await fms_unit.GotoPosition(fms_config.Positions.FixMirrorPosition_c)
    await meas_unit.EnableFlashLampPower(isLow=(flash_mode != 1))

    await as1_unit.Move(as1_pos)
    await as2_unit.Move(as2_pos)

    instrument = socket.gethostname()
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    os.makedirs(report_dir, exist_ok=True)
    with open(f"{report_dir}/flash_scan__{instrument}_{timestamp}.csv", 'w') as file:
        # First line: the call parameters. Second line: the end time of each measurement window.
        file.write(f"flash_scan(as1_pos={as1_pos:.3f}, as2_pos={as2_pos:.3f}, iterations={iterations}, flash_count={flash_count}, flash_mode={flash_mode}, maximum={maximum}, delay={delay:.3f}) on {instrument} started at {timestamp}\n")
        time_axis = "".join(f" ; {(window_us * (i + 1)):5.1f}" for i in range(window_count))
        file.write(f"time [us]{time_axis}\n")

        for i in range(iterations):
            # The stop flag is polled once per iteration; leaving the with-block closes the file.
            if GlobalVar.get_stop_gc() is True:
                return "flash_scan() stopped by user"

            await send_gc_msg(f"Scanning Iteration: {i + 1}")
            results = await flash_excitation_scan(window_us=window_us, window_count=window_count, flash_count=flash_count)
            # The scan returns the two PMT channels interleaved: even indices PMT1, odd indices PMT2.
            results_pmt1 = results[0::2]
            results_pmt2 = results[1::2]
            # For each PMT: write its row, show its graph, then pause so the graph stays visible.
            for label, values in (("PMT1", results_pmt1), ("PMT2", results_pmt2)):
                row = "".join(f" ; {value:5.1f}" for value in values[:window_count])
                file.write(f"{label} #{(i + 1):02d} {row}\n")
                await send_graph(f"{label}:", values, maximum)
                await asyncio.sleep(delay)

    return "flash_scan() done"


async def flash_excitation_scan(start_us=0.0, window_us=1.0, window_count=100, flash_count=50):
    """
    Measure the pulse counts of both counter channels in consecutive time windows after a flash.

    A trigger sequence is generated, loaded into the measurement unit and executed. The sequence
    fires the flash lamp flash_count times and, after each flash, records window_count windows
    for counter channels 0 and 1. Timer values are given in ticks of 10 ns (100 ticks per µs,
    as the arming time comments below show).

    start_us: The wait between setting and resetting the flash signal in µs, range [0.0, 671088.64].
    window_us: The length of one measurement window in µs, range [0.1, 671088.64].
    window_count: The number of windows measured after each flash, range [1, 2048].
    flash_count: The number of flashes to average, range [1, 65536].

    Returns a list with the values read back for the 2 * window_count result addresses, each
    divided by flash_count. Channel 0 and channel 1 are interleaved (even / odd indices); the
    callers in this file treat them as PMT1 and PMT2.
    Raises ValueError if an argument is outside of its range.
    """
    if (start_us < 0.0) or (start_us > 671088.64):
        raise ValueError("start_us must be in the range [0.0, 671088.64] us")
    if (window_us < 0.1) or (window_us > 671088.64):
        raise ValueError("window_us must be in the range [0.1, 671088.64] us")
    if (window_count < 1) or (window_count > 2048):
        raise ValueError("window_count must be in the range [1, 2048]")
    if (flash_count < 1) or (flash_count > 65536):
        raise ValueError("flash_count must be in the range [1, 65536]")

    # The arming time depends on the flash lamp power mode that is currently set in the firmware.
    high_power = (await meas_unit.endpoint.GetParameter(MeasurementParameter.FlashLampHighPowerEnable, timeout=1))[0]

    if not high_power:
        arming = 50000   #  500 us arming time
    else:
        arming = 120000  # 1200 us arming time
    delay = round(start_us * 100)
    window = round(window_us * 100)
    # 250 Hz flash frequency:
    # 400000 ticks are one 4 ms flash period; the remainder pads each flash cycle up to that period.
    loop_delay = 400000 - arming - delay - window * window_count

    op_id = 'flash_excitation_scan'
    seq_gen = meas_seq_generator()
    # PMT1 and PMT2 high voltage gate on
    seq_gen.SetSignals(OutputSignal.HVGatePMT1 | OutputSignal.HVGatePMT2)
    # Clear the result buffer
    seq_gen.SetAddrReg(relative=False, dataNotAddrSrc=False, sign=False, stackNotRegSrc=False, srcReg=0, dstReg=0, addr=0)
    seq_gen.Loop(window_count * 2)
    seq_gen.ClearResultBuffer(relative=True, dword=False, addrReg=0, addr=0)
    seq_gen.SetAddrReg(relative=True, dataNotAddrSrc=False, sign=False, stackNotRegSrc=False, srcReg=0, dstReg=0, addr=1)
    seq_gen.LoopEnd()

    seq_gen.Loop(flash_count)
    # Arm the flash lamp
    seq_gen.TimerWaitAndRestart(arming)
    seq_gen.SetSignals(OutputSignal.Flash)
    # Trigger the flash
    seq_gen.TimerWaitAndRestart(delay)
    seq_gen.ResetSignals(OutputSignal.Flash)
    # Start the scan with PMT1 and PMT2
    seq_gen.TimerWaitAndRestart(window)
    seq_gen.PulseCounterControl(channel=0, cumulative=False, resetCounter=True, resetPresetCounter=True, correctionOn=False)
    seq_gen.PulseCounterControl(channel=1, cumulative=False, resetCounter=True, resetPresetCounter=True, correctionOn=False)
    seq_gen.SetAddrReg(relative=False, dataNotAddrSrc=False, sign=False, stackNotRegSrc=False, srcReg=0, dstReg=0, addr=0)
    # Loop to measure each window and save the result
    seq_gen.Loop(window_count)
    seq_gen.TimerWaitAndRestart(window)
    seq_gen.PulseCounterControl(channel=0, cumulative=False, resetCounter=False, resetPresetCounter=True, correctionOn=True)
    seq_gen.PulseCounterControl(channel=1, cumulative=False, resetCounter=False, resetPresetCounter=True, correctionOn=True)
    # Channel 0 goes to result position 0 and channel 1 to position 1 of the current address,
    # which then advances by 2. The results are requested as cumulative, presumably so that the
    # counts of all flashes add up per window before the division below - TODO confirm.
    seq_gen.GetPulseCounterResult(channel=0, relative=True, resetCounter=True, cumulative=True, dword=False, addrPos=0, resultPos=0)
    seq_gen.GetPulseCounterResult(channel=1, relative=True, resetCounter=True, cumulative=True, dword=False, addrPos=0, resultPos=1)
    seq_gen.SetAddrReg(relative=True, dataNotAddrSrc=False, sign=False, stackNotRegSrc=False, srcReg=0, dstReg=0, addr=2)
    seq_gen.LoopEnd()
    # Skip the padding when the flash cycle is already as long as the flash period or longer.
    if (loop_delay > 0):
        seq_gen.TimerWaitAndRestart(loop_delay)
    seq_gen.LoopEnd()
    seq_gen.Stop(0)

    meas_unit.ClearOperations()
    meas_unit.resultAddresses[op_id] = range(0, (window_count * 2))
    await meas_unit.LoadTriggerSequence(op_id, seq_gen.currSequence)
    await meas_unit.ExecuteMeasurement(op_id)
    results = await meas_unit.ReadMeasurementValues(op_id)

    # Convert the values read back into the average per flash.
    results = [result / flash_count for result in results]

    return results

